import pathlib
import pickle

import numpy as np
from joblib import Parallel, delayed

from mcts import MCTS
from tree_env import SyntheticTree


def experiment(algorithm, tree, tau, alpha, number_of_atom):
    """Run a single MCTS search on ``tree`` and report its error metrics.

    The exploration coefficient, step size, discount ``gamma`` and the number
    of simulations are not parameters: they are read from the module-level
    ``exploration_coeff``, ``step_size``, ``gamma`` and ``n_simulations``.

    Args:
        algorithm: Algorithm identifier forwarded to ``MCTS``.
        tree: Environment to search; must expose ``optimal_v_root`` and
            ``max_mean``.
        tau: Forwarded to ``MCTS`` as ``tau``.
        alpha: Forwarded to ``MCTS`` as ``alpha``.
        number_of_atom: Forwarded to ``MCTS`` as ``number_of_atoms``.

    Returns:
        A tuple ``(diff, diff_uct, regret)``: the absolute difference between
        the value returned by ``MCTS.run`` and ``tree.optimal_v_root``, the
        absolute difference to ``tree.max_mean``, and the regret returned by
        ``MCTS.run``.
    """
    search_settings = dict(
        exploration_coeff=exploration_coeff,
        algorithm=algorithm,
        tau=tau,
        alpha=alpha,
        number_of_atoms=number_of_atom,
        step_size=step_size,
        gamma=gamma,
        update_type='mean',
    )
    searcher = MCTS(**search_settings)

    root_estimate, regret = searcher.run(tree, n_simulations)

    error_vs_optimal = np.abs(root_estimate - tree.optimal_v_root)
    error_vs_max_mean = np.abs(root_estimate - tree.max_mean)
    return error_vs_optimal, error_vs_max_mean, regret


# ---------------------------------------------------------------------------
# Search hyper-parameters shared by every run (individual algorithms may
# override some of them further down in the script).
# ---------------------------------------------------------------------------
exploration_coeff = 1.0
tau = 0.1
gamma = 1.0
step_size = 0.2

# Output root; the default exploration coefficient and tau are baked into
# the directory name.
folder_name = './logs/expl_%.2f_tau_%.2f' % (exploration_coeff, tau)

# ---------------------------------------------------------------------------
# Experiment sizes.
# ---------------------------------------------------------------------------
n_simulations = 1000  # simulations handed to each MCTS run
n_trees = 5           # trees per configuration
n_exp = 5             # repeated runs per tree

# ---------------------------------------------------------------------------
# Tree shapes.
# ---------------------------------------------------------------------------
# The specific (k, d) = (branching factor, depth) pairs from the plot.
k_d_combinations = [
    (16, 1), (200, 1),
    (14, 3), (16, 3),
    (16, 4), (200, 2),
]

# Full k/d grids kept for heatmap compatibility (if needed).
k_heat = [*range(2, 17, 2), 100, 200]
d_heat = list(range(1, 5))

# ---------------------------------------------------------------------------
# Algorithms and their sweep values.
# ---------------------------------------------------------------------------
# Every algorithm to run: identifier -> label.
algorithms = {
    'uct': 'UCT',
    'power-uct': 'Power-UCT',
    'dng': 'DNG',
    'fixed-depth-mcts': 'Fixed-Depth-MCTS',
    'ments': 'MENTS',
    'rents': 'RENTS',
    'tents': 'TENTS',
    'dents': 'DENTS',
    'catso': 'CATSO',
    'patso': 'PATSO',
}

# Alpha values for the power mean.
alphas = [1, 2, 4, 8, 10, 16]

# Number of atoms for the categorical methods.
atoms = [10]

# Process each k,d combination: for every eligible (algorithm, alpha, atoms)
# setting, run n_exp searches on each of n_trees cached trees and save the
# stacked metrics as .npy files.


def _run_single_experiment(tree, algorithm, search_exploration_coeff,
                           search_tau, alpha, number_of_atom):
    """Run one MCTS search on ``tree`` and return its error metrics.

    Everything that varies between runs is passed explicitly so the worker
    does not depend on the state of the surrounding loops at call time.
    ``step_size``, ``gamma`` and ``n_simulations`` are run-wide constants and
    are read from module scope.

    Args:
        tree: Environment to search; must expose ``optimal_v_root`` and
            ``max_mean``.
        algorithm: Algorithm identifier forwarded to ``MCTS``.
        search_exploration_coeff: Exploration coefficient for this algorithm.
        search_tau: ``tau`` value for this algorithm.
        alpha: Forwarded to ``MCTS`` as ``alpha``.
        number_of_atom: Forwarded to ``MCTS`` as ``number_of_atoms``.

    Returns:
        ``(diff, diff_uct, regret)``: absolute difference between the value
        returned by ``MCTS.run`` and ``tree.optimal_v_root``, absolute
        difference to ``tree.max_mean``, and the regret returned by
        ``MCTS.run``.
    """
    mcts_instance = MCTS(exploration_coeff=search_exploration_coeff,
                         algorithm=algorithm,
                         tau=search_tau,
                         alpha=alpha,
                         number_of_atoms=number_of_atom,
                         step_size=step_size,
                         gamma=gamma,
                         update_type='mean')

    v_hat, regret = mcts_instance.run(tree, n_simulations)
    diff = np.abs(v_hat - tree.optimal_v_root)
    diff_uct = np.abs(v_hat - tree.max_mean)
    return diff, diff_uct, regret


# Algorithms that are skipped whenever alpha > 1.
# NOTE(review): 'power-uct' is in this group, so with the current `alphas`
# only alpha == 1 and alpha == 10 ever produce runs -- confirm this is the
# intended sweep.
_SINGLE_ALPHA_ALGS = frozenset({
    'uct', 'dng', 'fixed-depth-mcts', 'ments', 'tents', 'rents', 'dents',
    'power-uct',
})
# Algorithms that are only run at alpha == 10.
_ALPHA_10_ONLY_ALGS = frozenset({'catso', 'patso'})

for k, d in k_d_combinations:
    subfolder_name = folder_name + '/k_' + str(k) + '_d_' + str(d)
    subfolder = pathlib.Path(subfolder_name)
    subfolder.mkdir(parents=True, exist_ok=True)

    for number_of_atom in atoms:
        for alpha in alphas:
            for alg in algorithms:
                # Skip conditions based on algorithm requirements
                if alg in _SINGLE_ALPHA_ALGS and alpha > 1:
                    continue
                if alg in _ALPHA_10_ONLY_ALGS and alpha != 10:
                    continue

                # Special exploration coefficients for certain algorithms
                current_exploration_coeff = exploration_coeff
                current_tau = tau

                if alg in {'uct', 'fixed-depth-mcts'}:
                    current_exploration_coeff = 0.05
                elif alg == 'dents':
                    current_exploration_coeff = 0.75
                    current_tau = 0.5

                print('Branching factor: %d, Depth: %d, Alg: %s, Alpha: %f' % (k, d, alg, alpha))

                # Shared suffix of every file written for this setting.
                run_tag = '%s_%f_%d' % (alg, alpha, number_of_atom)
                out = []

                for w in range(n_trees):
                    tree_path = subfolder / ('tree%d_%s.pkl' % (w, run_tag))
                    try:
                        # NOTE: pickle.load can execute arbitrary code; this is
                        # only acceptable because the cache files are produced
                        # by this script itself.
                        with open(tree_path, 'rb') as f:
                            tree = pickle.load(f)
                    except FileNotFoundError:
                        print('Tree not found! Creating new tree...')
                        tree = SyntheticTree(k, d, alg, current_tau, alpha, number_of_atom, gamma, step_size)
                        # Cache the tree so later invocations reuse it.
                        with open(tree_path, 'wb') as f:
                            pickle.dump(tree, f)

                    # Run the n_exp repetitions for this tree in parallel
                    out.extend(Parallel(n_jobs=-1)(
                        delayed(_run_single_experiment)(
                            tree, alg, current_exploration_coeff, current_tau,
                            alpha, number_of_atom)
                        for _ in range(n_exp)
                    ))

                # One row per run; columns are (diff, diff_uct, regret).
                out = np.array(out)

                for column, metric in enumerate(('diff', 'diff_uct', 'regret')):
                    np.save(subfolder / ('%s_%s.npy' % (metric, run_tag)), out[:, column])

print("All experiments completed!")